{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import json\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "wordnet_lemmatizer = WordNetLemmatizer()\n",
    "\n",
    "# Key under which parse_entry stores the grouped question tokens on each entry.\n",
    "TOK = 'question_tok_concol'\n",
    "# Key under which parse_entry stores the type tags, one list per token group.\n",
    "TYPE = 'question_type_concol_list'\n",
    "# NOTE(review): not referenced anywhere else in this notebook.\n",
    "HEADER = 'header_tok'\n",
    "# Randomly chosen placeholder string; parse_entry uses it as the type tag for\n",
    "# question tokens that match no table name, column header or database value.\n",
    "NONE = \"te8r2ed\" \n",
    "# NOTE(review): not referenced anywhere else in this notebook.\n",
    "DEFAULT_TABLENAME = \"DEFAULT_NAME\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def refractor_db_data(data_, table):\n",
    "    rec = {}\n",
    "    for table_name in table['table_names']:\n",
    "        rec[table_name] = []\n",
    "        \n",
    "    for entry in table['column_names']:\n",
    "        if entry[0] < 0:\n",
    "            continue\n",
    "        this_table = table['table_names'][entry[0]]\n",
    "        rec[this_table].append(entry[1])\n",
    "    \n",
    "    res = {}  # val - (column, table)\n",
    "    for k, v in data_.items():\n",
    "        this_table = k.lower().replace('_', \" \")\n",
    "        if this_table not in rec:\n",
    "#             print(this_table, \" not in \", rec.keys())\n",
    "            continue\n",
    "        for row in v:\n",
    "            for idx, val in enumerate(row):\n",
    "                if val is None:\n",
    "                    continue\n",
    "                if type(val) is str and len(val) > 0:\n",
    "                    \n",
    "                    if idx < len(rec[this_table]): \\\n",
    "                        res[val.lower()] = (rec[this_table][idx].split() , this_table.split())\n",
    "                    continue\n",
    "                    \n",
    "                if idx < len(rec[this_table]): \\\n",
    "                    res[str(val)] = (rec[this_table][idx].split() , this_table.split())\n",
    "                \n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "league_1  not in tables.\n",
      "tpch_1  not in tables.\n",
      "manufacturer_1  not in tables.\n",
      "wikipedia_1  not in tables.\n",
      "chbenchmark_1  not in tables.\n",
      "predictions_1  not in tables.\n",
      "triage_1  not in tables.\n",
      "classic_1  not in tables.\n",
      "ship  not in tables.\n",
      "pitchfork_1  not in tables.\n",
      "tpcds_1  not in tables.\n",
      "mlb_1  not in tables.\n",
      "phone  not in tables.\n",
      "gym  not in tables.\n"
     ]
    }
   ],
   "source": [
    "dpath = \"/data/projects/nl2sql/hold_out/test.json\"\n",
    "# dpath = \"/data/projects/nl2sql/datasets/data/dev.json\"\n",
    "tpath = \"/data/projects/nl2sql/datasets/data/tables.json\"\n",
    "datapath = \"/data/projects/nl2sql/models/ruismethod/data/db_data.json\"\n",
    "\n",
    "# Load the question entries, the schema descriptions and the database contents.\n",
    "with open(dpath) as f:\n",
    "    data = json.load(f)\n",
    "with open(tpath) as f:\n",
    "    tables = json.load(f)\n",
    "with open(datapath) as f:\n",
    "    db_data = json.load(f)\n",
    "\n",
    "# Index the schemas by database id.\n",
    "tables_dict = dict((schema['db_id'], schema) for schema in tables)\n",
    "\n",
    "# Build the value lookup for every database that has a schema; report the rest.\n",
    "db_data_dict = {}\n",
    "for db_content in db_data:\n",
    "    db_id = db_content['db_id']\n",
    "    if db_id not in tables_dict:\n",
    "        print(db_id, \" not in tables.\")\n",
    "        continue\n",
    "    db_data_dict[db_id] = refractor_db_data(db_content['data'], tables_dict[db_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def toksEQ(toks1, toks2):\n",
    "    \"\"\"Return True when both token sequences are equal after lemmatizing each token.\"\"\"\n",
    "    def lemma_string(toks):\n",
    "        # lemmatize token by token, then compare as one space-joined string\n",
    "        return \" \".join(wordnet_lemmatizer.lemmatize(tok) for tok in toks)\n",
    "\n",
    "    return lemma_string(toks1) == lemma_string(toks2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def isNumber(val):\n",
    "    try:\n",
    "        int(val)\n",
    "        return True\n",
    "    except:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_header(toks, idx, num_toks, header_toks):\n",
    "    for endIdx in reversed(range(idx+1, num_toks+1)):\n",
    "        sub_toks = toks[idx: endIdx]\n",
    "#         for this_header_tok in header_toks:\n",
    "#             if toksEQ(sub_toks, this_header_tok):\n",
    "#                 return endIdx, sub_toks, this_header_tok\n",
    "        if sub_toks in header_toks:\n",
    "            return endIdx, sub_toks\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_val(toks, idx, num_toks, db_data):\n",
    "    for endIdx in reversed(range(idx+1, num_toks+1)):\n",
    "        sub_toks = toks[idx: endIdx]\n",
    "        key = \" \".join(sub_toks)\n",
    "        if key in db_data:\n",
    "            return endIdx, sub_toks\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_table(toks, idx, num_toks, table_names):\n",
    "    table_toks = [name.split() for name in table_names]\n",
    "    for endIdx in reversed(range(idx+1, num_toks+1)):\n",
    "        sub_toks = toks[idx: endIdx]\n",
    "#         for this_table_tok in table_toks:\n",
    "#             if toksEQ(sub_toks, this_table_tok):\n",
    "#                 return endIdx, sub_toks, this_table_tok\n",
    "        if sub_toks in table_toks:\n",
    "            return endIdx, sub_toks\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def parse_entry(entry, tables_dict, db_data_dict):\n",
    "    \"\"\"Group the question tokens of `entry` and tag every group with a type.\n",
    "\n",
    "    Scanning left to right, the longest span at each position is matched\n",
    "    against, in order: table names, column headers, database values. The\n",
    "    matching is done on lemmatized lower-cased tokens, while the stored groups\n",
    "    hold the lower-cased (non-lemmatized) tokens. Unmatched tokens become\n",
    "    single-token groups tagged [NONE].\n",
    "\n",
    "    Args:\n",
    "        entry: dict with 'db_id' and 'question_toks'; modified in place.\n",
    "        tables_dict: db_id -> schema dict ('table_names', 'column_names').\n",
    "        db_data_dict: db_id -> {value: (column tokens, table tokens)}; a\n",
    "            missing db_id means no value matching for this entry.\n",
    "\n",
    "    Returns:\n",
    "        The same `entry`, with the groups under TOK and the tags under TYPE.\n",
    "    \"\"\"\n",
    "    schema = tables_dict[entry['db_id']]\n",
    "    value_lookup = db_data_dict.get(entry['db_id'], {})\n",
    "    lowered = [tok.lower() for tok in entry['question_toks']]\n",
    "    lemmas = [wordnet_lemmatizer.lemmatize(tok) for tok in lowered]\n",
    "\n",
    "    # Tokenized column headers, plus header text -> tokens of the owning table.\n",
    "    header_toks = []\n",
    "    col2table = {}\n",
    "    for column in schema['column_names']:\n",
    "        name_toks = column[1].split()\n",
    "        header_toks.append(name_toks)\n",
    "        if column[0] < 0: # \"*\"\n",
    "            owner = [\"all\"]\n",
    "        else:\n",
    "            owner = schema['table_names'][column[0]].split()\n",
    "        col2table[\" \".join(name_toks)] = owner\n",
    "\n",
    "    total = len(lowered)\n",
    "    tok_concol, type_concol = [], []\n",
    "    pos = 0\n",
    "    while pos < total:\n",
    "        stop, matched = group_table(lemmas, pos, total, schema[\"table_names\"])\n",
    "        if matched:\n",
    "            tag = [\"table\"]\n",
    "        else:\n",
    "            stop, matched = group_header(lemmas, pos, total, header_toks)\n",
    "            if matched:\n",
    "                tag = col2table[\" \".join(matched)]\n",
    "            else:\n",
    "                stop, matched = group_val(lemmas, pos, total, value_lookup)\n",
    "                if matched:\n",
    "                    value_key = \" \".join(matched)\n",
    "                    col, table_ = value_lookup[value_key]\n",
    "                    tag = col + table_\n",
    "                    if isNumber(value_key):\n",
    "                        tag = tag + [\"number\"]\n",
    "                else:\n",
    "                    # nothing matched: emit the single token with the NONE tag\n",
    "                    stop, tag = pos + 1, [NONE]\n",
    "        tok_concol.append(lowered[pos:stop])\n",
    "        type_concol.append(tag)\n",
    "        pos = stop\n",
    "\n",
    "    entry[TOK] = tok_concol\n",
    "    entry[TYPE] = type_concol\n",
    "    return entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Annotate every question entry in place with its token groups and type tags.\n",
    "for question_entry in data:\n",
    "    parse_entry(question_entry, tables_dict, db_data_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the annotated entries to disk as indented JSON.\n",
    "with open('test_type.json', 'w') as out_file:\n",
    "    json.dump(data, out_file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
